/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_NNACL_FP32_POOLING_NEON_H_
#define MINDSPORE_NNACL_FP32_POOLING_NEON_H_

#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/intrinsics/ms_simd_neon_instructions.h"

#ifdef __cplusplus
extern "C" {
#endif

/*
 * SIMD configuration for the functions in this header.
 *  - MS_SIMD_INSTRUCTION is set to the NEON instruction selector.
 *  - BLOCK_NUM is the number of channels each loop iteration below advances by.
 *  - MS_SIMD_NEON is defined as an empty marker macro.
 * All three are #undef'd again at the end of this file.
 * NOTE(review): presumably the generic SIMD_* macros from the included headers expand
 * according to MS_SIMD_INSTRUCTION and BLOCK_NUM - confirm against ms_simd_instructions.h.
 */
#define MS_SIMD_INSTRUCTION MS_SIMD_NEON_INSTRUCTION
#define BLOCK_NUM 4
#define MS_SIMD_NEON

/**
 * Average-pools one pooling window for groups of BLOCK_NUM consecutive channels.
 *
 * Starting at channel `ci`, a group is processed as long as `ci < channel - BLOCK_NUM + 1`.
 * For each group the input vectors at every window position (h, w), with
 * h in [real_win_h_start, real_win_h_end) and w in [real_win_w_start, real_win_w_end),
 * are summed; the sum is divided by the number of window positions, raised to at least
 * `minf`, lowered to at most `maxf`, and stored at `dst_plane_ptr + ci`.
 *
 * Input for position (h, w) is read from
 *   src_plane_ptr + ci + ((in_h_index + h) * in_w + in_w_index + w) * channel
 * so `channel` is the stride between neighbouring spatial positions.
 *
 * @param ci                first channel index to process
 * @param src_plane_ptr     base pointer of the input the offsets above are applied to
 * @param channel           channel count; also the stride between spatial positions
 * @param dst_plane_ptr     base pointer of the output; results go to dst_plane_ptr + ci
 * @param real_win_h_start  first window row (inclusive)
 * @param real_win_h_end    last window row (exclusive)
 * @param real_win_w_start  first window column (inclusive)
 * @param real_win_w_end    last window column (exclusive)
 * @param in_h_index        row offset added to every window row
 * @param in_w              multiplier applied to the row index when forming the offset
 * @param in_w_index        column offset added to every window column
 * @param minf              lower clamp applied to the average
 * @param maxf              upper clamp applied to the average
 * @return `ci` advanced by BLOCK_NUM for every group processed (unchanged when no full
 *         group fits). Channels from the returned index up to `channel` are not touched.
 *         NOTE(review): presumably the caller finishes those with a scalar loop - confirm.
 */
static inline int AvgPoolingBatchNEON(int ci, const float *src_plane_ptr, int channel,
  float *dst_plane_ptr, int real_win_h_start, int real_win_h_end, int real_win_w_start, int real_win_w_end,
  int in_h_index, int in_w, int in_w_index, float minf, float maxf) {
  SIMD_F32 lower_bound = SIMD_MOV_F32(minf);
  SIMD_F32 upper_bound = SIMD_MOV_F32(maxf);

  // The number of window positions is the same for every channel group, so it is computed
  // once here instead of being recounted inside the innermost loop. An empty range in
  // either dimension yields zero, matching what an element-by-element count would give.
  const int win_rows = real_win_h_end > real_win_h_start ? real_win_h_end - real_win_h_start : 0;
  const int win_cols = real_win_w_end > real_win_w_start ? real_win_w_end - real_win_w_start : 0;
  const float win_size = (float)(win_rows * win_cols);
  SIMD_F32 win_size_vec = SIMD_MOV_F32(win_size);

  const int simd_end = channel - BLOCK_NUM + 1;
  for (; ci < simd_end; ci += BLOCK_NUM) {
    const float *in_base = src_plane_ptr + ci;
    SIMD_F32 sum = SIMD_SET0_F32;
    for (int h = real_win_h_start; h < real_win_h_end; h++) {
      // Row part of the offset, shared by every column of this window row.
      const int row_base = (in_h_index + h) * in_w + in_w_index;
      for (int w = real_win_w_start; w < real_win_w_end; w++) {
        const float *in_ptr = in_base + (row_base + w) * channel;
        sum = SIMD_ADD_F32(sum, SIMD_LD_F32(in_ptr));
      }
    }
    SIMD_F32 avg = SIMD_DIV_F32(sum, win_size_vec);
    avg = SIMD_MAX_F32(avg, lower_bound);
    avg = SIMD_MIN_F32(avg, upper_bound);
    SIMD_ST_F32(dst_plane_ptr + ci, avg);
  }
  return ci;
}

/**
 * Max-pools one pooling window for groups of BLOCK_NUM consecutive channels.
 *
 * Starting at channel `ci`, a group is processed as long as `ci < channel - BLOCK_NUM + 1`.
 * The running maximum of each group starts at `minf` (broadcast), is combined with the
 * input vector at every window position (h, w), h in [real_win_h_start, real_win_h_end)
 * and w in [real_win_w_start, real_win_w_end), is then lowered to at most `maxf`, and is
 * stored at `dst_plane_ptr + ci`.
 *
 * Input for position (h, w) is read from
 *   src_plane_ptr + ci + ((in_h_index + h) * in_w + in_w_index + w) * channel.
 *
 * @param ci                first channel index to process
 * @param src_plane_ptr     base pointer of the input the offsets above are applied to
 * @param channel           channel count; also the stride between spatial positions
 * @param dst_plane_ptr     base pointer of the output; results go to dst_plane_ptr + ci
 * @param real_win_h_start  first window row (inclusive)
 * @param real_win_h_end    last window row (exclusive)
 * @param real_win_w_start  first window column (inclusive)
 * @param real_win_w_end    last window column (exclusive)
 * @param in_h_index        row offset added to every window row
 * @param in_w              multiplier applied to the row index when forming the offset
 * @param in_w_index        column offset added to every window column
 * @param minf              initial value of the running maximum (acts as lower clamp)
 * @param maxf              upper clamp applied to the result
 * @return `ci` advanced by BLOCK_NUM for every group processed; channels from the
 *         returned index up to `channel` are not touched by this function.
 */
static inline int MaxPoolingBatchNEON(int ci, const float *src_plane_ptr, int channel,
  float *dst_plane_ptr, int real_win_h_start, int real_win_h_end, int real_win_w_start, int real_win_w_end,
  int in_h_index, int in_w, int in_w_index, float minf, float maxf) {
  SIMD_F32 floor_vec = SIMD_MOV_F32(minf);
  SIMD_F32 ceil_vec = SIMD_MOV_F32(maxf);
  const int simd_end = channel - BLOCK_NUM + 1;

  while (ci < simd_end) {
    const float *in_base = src_plane_ptr + ci;
    // Seeding with minf means the stored value is never below the lower clamp.
    SIMD_F32 best = floor_vec;
    for (int row = real_win_h_start; row < real_win_h_end; ++row) {
      // Row part of the offset, shared by every column of this window row.
      const int row_base = (in_h_index + row) * in_w + in_w_index;
      for (int col = real_win_w_start; col < real_win_w_end; ++col) {
        const float *in_ptr = in_base + (row_base + col) * channel;
        best = SIMD_MAX_F32(best, SIMD_LD_F32(in_ptr));
      }
    }
    best = SIMD_MIN_F32(best, ceil_vec);
    SIMD_ST_F32(dst_plane_ptr + ci, best);
    ci += BLOCK_NUM;
  }
  return ci;
}

/**
 * 3D max pooling of one window for groups of BLOCK_NUM consecutive channels.
 *
 * Starting at channel `c_idx`, a group is processed as long as
 * `c_idx < channel - BLOCK_NUM + 1`. The running maximum of each group starts at -FLT_MAX
 * and is combined with the input vector at every (d, h, w) with d in [d_start, d_end),
 * h in [h_start, h_end) and w in [w_start, w_end). The result is stored at `out + c_idx`
 * without any further clamping.
 *
 * Input for position (d, h, w) is read from
 *   src_batch_ptr + c_idx + d * d_stride + h * h_stride + w * channel.
 *
 * @param c_idx          first channel index to process
 * @param src_batch_ptr  base pointer of the input the offsets above are applied to
 * @param channel        channel count; also the stride between neighbouring w positions
 * @param out            base pointer of the output; results go to out + c_idx
 * @param d_start        first depth index (inclusive)
 * @param d_end          last depth index (exclusive)
 * @param h_start        first row index (inclusive)
 * @param h_end          last row index (exclusive)
 * @param w_start        first column index (inclusive)
 * @param w_end          last column index (exclusive)
 * @param d_stride       element stride applied to the depth index
 * @param h_stride       element stride applied to the row index
 * @return `c_idx` advanced by BLOCK_NUM for every group processed; channels from the
 *         returned index up to `channel` are not touched by this function.
 */
static inline int MaxPooling3DNEON(int c_idx, const float *src_batch_ptr, int channel, float *out,
  int d_start, int d_end, int h_start, int h_end, int w_start, int w_end, int d_stride, int h_stride) {
  const int simd_end = channel - BLOCK_NUM + 1;

  while (c_idx < simd_end) {
    const float *in_base = src_batch_ptr + c_idx;
    SIMD_F32 best = SIMD_MOV_F32(-FLT_MAX);
    for (int d = d_start; d < d_end; ++d) {
      const float *plane_ptr = in_base + d * d_stride;
      for (int h = h_start; h < h_end; ++h) {
        const float *row_ptr = plane_ptr + h * h_stride;
        for (int w = w_start; w < w_end; ++w) {
          const float *in_ptr = row_ptr + w * channel;
          best = SIMD_MAX_F32(SIMD_LD_F32(in_ptr), best);
        }
      }
    }
    SIMD_ST_F32(out + c_idx, best);
    c_idx += BLOCK_NUM;
  }
  return c_idx;
}

/**
 * 3D average pooling of one window for groups of BLOCK_NUM consecutive channels.
 *
 * Starting at channel `c_idx`, a group is processed as long as
 * `c_idx < channel - BLOCK_NUM + 1`. The input vectors at every (d, h, w) with
 * d in [d_start, d_end), h in [h_start, h_end) and w in [w_start, w_end) are summed, the
 * sum is divided by `divisor` (supplied by the caller, not derived from the window
 * bounds here), and the quotient is stored at `out + c_idx` without clamping.
 *
 * Input for position (d, h, w) is read from
 *   src_batch_ptr + c_idx + d * d_stride + h * h_stride + w * channel.
 *
 * @param c_idx          first channel index to process
 * @param src_batch_ptr  base pointer of the input the offsets above are applied to
 * @param channel        channel count; also the stride between neighbouring w positions
 * @param out            base pointer of the output; results go to out + c_idx
 * @param d_start        first depth index (inclusive)
 * @param d_end          last depth index (exclusive)
 * @param h_start        first row index (inclusive)
 * @param h_end          last row index (exclusive)
 * @param w_start        first column index (inclusive)
 * @param w_end          last column index (exclusive)
 * @param d_stride       element stride applied to the depth index
 * @param h_stride       element stride applied to the row index
 * @param divisor        value the accumulated sum is divided by
 * @return `c_idx` advanced by BLOCK_NUM for every group processed; channels from the
 *         returned index up to `channel` are not touched by this function.
 */
static inline int AvgPooling3DNEON(int c_idx, const float *src_batch_ptr, int channel, float *out,
  int d_start, int d_end, int h_start, int h_end, int w_start, int w_end, int d_stride, int h_stride, int divisor) {
  const int simd_end = channel - BLOCK_NUM + 1;

  while (c_idx < simd_end) {
    const float *in_base = src_batch_ptr + c_idx;
    SIMD_F32 sum = SIMD_SET0_F32;
    for (int d = d_start; d < d_end; ++d) {
      const float *plane_ptr = in_base + d * d_stride;
      for (int h = h_start; h < h_end; ++h) {
        const float *row_ptr = plane_ptr + h * h_stride;
        for (int w = w_start; w < w_end; ++w) {
          const float *in_ptr = row_ptr + w * channel;
          sum = SIMD_ADD_F32(SIMD_LD_F32(in_ptr), sum);
        }
      }
    }
    sum = SIMD_DIV_F32(sum, SIMD_MOV_F32(divisor));
    SIMD_ST_F32(out + c_idx, sum);
    c_idx += BLOCK_NUM;
  }
  return c_idx;
}

/* Remove the configuration macros defined near the top of this header so they do not
 * remain defined for code that follows this include. */
#undef MS_SIMD_INSTRUCTION
#undef BLOCK_NUM

#undef MS_SIMD_NEON
#ifdef __cplusplus
}
#endif
#endif
